# Copyright 2019-2020 QuantumBlack Visual Analytics Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND
# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS
# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#
# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo
# (either separately or in combination, "QuantumBlack Trademarks") are
# trademarks of QuantumBlack. The License does not grant you any right or
# license to the QuantumBlack Trademarks. You may not use the QuantumBlack
# Trademarks or any confusingly similar mark as a trademark for your product,
#     or use the QuantumBlack Trademarks in any other manner that might cause
# confusion in the marketplace, including but not limited to in advertising,
# on websites, or on software.
#
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
from typing import Dict

import numpy as np
import pandas as pd
import pytest
from pgmpy.models import BayesianNetwork as BayesianModel
from sklearn.datasets import load_iris

from causalnex.discretiser import Discretiser
from causalnex.network import BayesianNetwork
from causalnex.structure import StructureModel
from causalnex.structure.notears import from_pandas


# Ignoring limit of 1000 lines per module, since this module contains test sets.
# pylint: disable=C0302
@pytest.fixture
def train_model() -> StructureModel:
    """
    Structure shared by the tests in this module; the data fixtures are built to be consistent with it.

    Root (cause-only) nodes: [d, e]
    Leaf (effect-only) nodes: [a, c]
    Intermediate (cause and effect) node: [b]

            d
         ↙  ↓  ↘
        a ← b → c
            ↑  ↗
            e
    """
    # Adjacency list: parent -> children, listed in the original edge-insertion order.
    children = {
        "b": ["a", "c"],
        "d": ["a", "c", "b"],
        "e": ["c", "b"],
    }
    structure = StructureModel()
    structure.add_edges_from(
        [(parent, child) for parent, kids in children.items() for child in kids]
    )
    return structure


@pytest.fixture
def train_model_idx(train_model) -> BayesianModel:
    """
    Same structure as the train_model() fixture, but as a pgmpy model whose node names
    are the integers 0 to 4, mapped by:

    {"a": 0, "b": 1, "c": 2, "d": 3, "e": 4}
    """
    idx_map = {"a": 0, "b": 1, "c": 2, "d": 3, "e": 4}
    model = BayesianModel()
    # Re-label every edge of train_model with the integer node ids.
    model.add_edges_from([(idx_map[u], idx_map[v]) for u, v in train_model.edges])
    return model


@pytest.fixture
def train_data() -> pd.DataFrame:
    """
    Training data for testing Bayesian Networks. There are 98 samples, with 5 columns:

    - a: {"a", "b", "c", "d"}
    - b: {"x", "y", "z"}
    - c: 0.0 - 100.0
    - d: Boolean
    - e: Boolean

    This data was generated by constructing the Bayesian Model train_model(), and then sampling
    from this structure. Since d and e have no parents in train_model(), these were sampled first for
    each row (from their respective pre-defined distributions). This then allows the sampling of all further
    variables based on their conditional dependencies.

    Empirical estimates of the distributions sampled from can be viewed by inspecting the
    train_data_discrete_cpds / train_data_idx_cpds fixtures.
    """

    data_arr = [
        ["a", "x", 73.78658346945414, False, False],
        ["d", "x", 12.765853213346603, False, False],
        ["c", "y", 22.43657132589221, False, False],
        ["a", "x", 4.267744937038964, False, False],
        ["b", "x", 62.87087344904927, False, False],
        ["c", "x", 31.55295196889971, False, False],
        ["a", "x", 37.403388911083965, False, False],
        ["b", "x", 63.171968604247155, False, False],
        ["d", "x", 11.140539452118263, False, False],
        ["d", "x", 0.1555338799942385, True, False],
        ["c", "x", 9.269926225399187, False, True],
        ["b", "z", 75.38846241765208, True, True],
        ["c", "z", 33.10212378889936, False, True],
        ["b", "z", 57.04657630213301, True, True],
        ["b", "x", 72.03855905511072, True, False],
        ["c", "x", 5.106018765399956, False, False],
        ["c", "z", 5.802617702038839, False, True],
        ["c", "x", 17.22538330530506, False, False],
        ["a", "y", 87.05395007052729, False, False],
        ["d", "y", 19.09989481093348, False, False],
        ["c", "x", 4.313272835124353, True, False],
        ["b", "x", 13.660704178900938, True, True],
        ["b", "x", 7.693287813764131, False, False],
        ["c", "y", 32.791770073523246, False, False],
        ["c", "y", 12.039098492465282, False, False],
        ["a", "x", 51.97718339128754, False, False],
        ["d", "x", 8.393970656769238, False, False],
        ["a", "x", 0.3610815726384886, False, False],
        ["a", "y", 35.31788713900731, True, False],
        ["b", "x", 35.84702992379284, False, True],
        ["c", "y", 32.872350426703356, True, False],
        ["a", "x", 21.218746335586868, False, True],
        ["b", "y", 71.5495653029006, True, False],
        ["c", "x", 15.393846082097575, False, False],
        ["d", "y", 4.514559208625406, False, False],
        ["d", "x", 0.704928173400301, False, False],
        ["c", "y", 34.10829794112354, True, False],
        ["d", "x", 6.84602512195673, False, False],
        ["b", "y", 25.43743439885204, False, False],
        ["d", "x", 7.544831467091971, False, False],
        ["d", "x", 13.923699372025073, False, False],
        ["b", "x", 21.493005760070915, False, False],
        ["a", "x", 41.353977640369436, False, False],
        ["c", "z", 10.015835005248583, True, True],
        ["c", "z", 29.40115954319444, False, True],
        ["c", "x", 17.305145945035388, False, False],
        ["b", "x", 57.3687951851441, False, False],
        ["a", "x", 59.31395756039643, False, False],
        ["d", "x", 19.557939187075984, False, False],
        ["d", "y", 15.739556224725082, False, False],
        ["c", "x", 6.850626809845993, True, False],
        ["c", "x", 7.774579861173826, False, False],
        ["c", "x", 20.807136344297092, True, False],
        ["b", "y", 29.406207780312343, False, False],
        ["a", "x", 34.38851648220974, False, False],
        ["d", "x", 1.0951104244381218, True, False],
        ["c", "x", 37.27483338042188, False, False],
        ["b", "x", 15.745994603442064, False, False],
        ["c", "x", 17.78180189764816, False, True],
        ["a", "x", 17.067548428231493, True, False],
        ["c", "x", 26.857320012899727, False, False],
        ["a", "x", 41.0038510689549, False, True],
        ["d", "x", 0.2299684913699096, False, True],
        ["a", "x", 57.35885570158893, True, False],
        ["d", "x", 12.40118443712448, False, False],
        ["c", "x", 22.624550487374112, False, False],
        ["a", "x", 93.08587619178269, False, False],
        ["b", "y", 18.33030505634329, False, False],
        ["a", "z", 64.29945681859853, False, True],
        ["b", "x", 73.66024742961967, False, False],
        ["b", "x", 16.717397443478287, False, True],
        ["c", "y", 4.642615342125205, False, True],
        ["c", "x", 9.431345661106931, False, False],
        ["c", "y", 31.76238774237109, False, False],
        ["c", "y", 3.6961806894707965, False, False],
        ["d", "y", 2.298895066631253, True, False],
        ["d", "y", 13.222298172220462, False, False],
        ["c", "x", 28.301638775451153, False, False],
        ["d", "x", 7.702270580869413, True, False],
        ["a", "y", 41.38492280508702, True, False],
        ["d", "x", 13.047815503255656, True, False],
        ["c", "x", 22.14641490202623, False, False],
        ["b", "z", 43.13007970158368, False, True],
        ["b", "x", 60.09518672623882, True, False],
        ["a", "x", 79.6370082234198, False, False],
        ["d", "x", 16.60880504367762, False, False],
        ["a", "z", 22.88783470451029, False, True],
        ["a", "x", 33.66416643964188, False, False],
        ["b", "y", 69.91787304290465, True, True],
        ["c", "x", 31.941092922567663, True, False],
        ["d", "x", 16.739638908154518, False, False],
        ["a", "z", 11.129589373273108, False, True],
        ["d", "y", 4.96943558614434, True, False],
        ["d", "y", 6.585354730457387, False, False],
        ["d", "x", 9.859942318446954, False, False],
        ["b", "z", 18.541485302271496, False, True],
        ["a", "x", 87.53473074574995, True, False],
        ["a", "z", 59.61068083691302, False, True],
    ]

    data = pd.DataFrame(data_arr, columns=["a", "b", "c", "d", "e"])
    return data


@pytest.fixture
def train_data_discrete(train_data) -> pd.DataFrame:
    """
    Copy of train_data with the continuous column "c" replaced by one of 5 bucket ids:
    - 0: x < 20
    - 1: 20 <= x < 40
    - 2: 40 <= x < 60
    - 3: 60 <= x < 80
    - 4: 80 <= x
    """
    upper_bounds = (20, 40, 60, 80)

    def bucket(value):
        # Index of the first upper bound the value lies below; otherwise the top bucket.
        for index, bound in enumerate(upper_bounds):
            if value < bound:
                return index
        return len(upper_bounds)

    discretised = train_data.copy(deep=True)  # type: pd.DataFrame
    discretised["c"] = discretised["c"].map(bucket)
    return discretised


@pytest.fixture
def train_data_idx(train_data) -> pd.DataFrame:
    """
    train_data with every column integer-encoded as 0..n-1: categorical columns through
    fixed lookup tables, booleans as 0/1, and "c" bucketed into 5 bins of width 20
    (the same bins as train_data_discrete).
    """
    boolean_codes = {True: 1, False: 0}
    lookups = {
        "a": {"a": 0, "b": 1, "c": 2, "d": 3},
        "b": {"x": 0, "y": 1, "z": 2},
        "d": boolean_codes,
        "e": boolean_codes,
    }

    def bucket_c(value):
        # First bin whose upper bound exceeds the value; 4 when none does.
        return next((i for i, bound in enumerate((20, 40, 60, 80)) if value < bound), 4)

    encoded = train_data.copy(deep=True)  # type: pd.DataFrame
    for column, lookup in lookups.items():
        encoded[column] = encoded[column].map(lookup)
    encoded["c"] = encoded["c"].apply(bucket_c)
    return encoded


@pytest.fixture
def train_data_idx_cpds(train_data_idx) -> Dict[str, np.ndarray]:
    """Unsmoothed (pseudo-count 0) CPDs of train_model, estimated by counting rows of train_data_idx."""
    cpds = create_cpds(train_data_idx, pc=0)
    return cpds


@pytest.fixture
def train_data_discrete_cpds(train_data_discrete) -> Dict[str, np.ndarray]:
    """Unsmoothed (pseudo-count 0) CPDs of train_model, estimated by counting rows of train_data_discrete."""
    cpds = create_cpds(train_data_discrete, pc=0)
    return cpds


@pytest.fixture
def train_data_discrete_cpds_k2(train_data_discrete) -> Dict[str, np.ndarray]:
    """CPDs of train_model estimated from train_data_discrete with a pseudo-count of 1 added to every cell."""
    smoothed = create_cpds(train_data_discrete, pc=1)
    return smoothed


def _conditional_probability(df, node, state, parent_states, n_states, pc):
    """Estimate P(node = state | parents = parent_states) by counting rows of ``df``.

    ``parent_states`` maps parent column -> parent value; an empty mapping gives the
    unconditional probability. ``pc`` is added to every count (``pc * n_states`` to the
    denominator). When the parent combination never occurs in ``df`` the conditional
    distribution is undefined, so the uniform value ``1 / n_states`` is returned instead
    of dividing by zero.
    """
    parent_mask = np.ones(len(df), dtype=bool)
    for parent, value in parent_states.items():
        parent_mask &= (df[parent] == value).to_numpy()

    n_parent = int(parent_mask.sum())
    if n_parent == 0:
        return 1 / n_states

    n_joint = int((parent_mask & (df[node] == state).to_numpy()).sum())
    return (n_joint + pc) / (n_parent + pc * n_states)


def create_cpds(data, pc=0):
    """
    Estimate the CPD of every node of train_model from ``data`` by counting.

    Args:
        data: DataFrame with discrete columns "a".."e".
        pc: pseudo-count added to every cell (0 gives plain relative frequencies,
            1 gives add-one smoothing).

    Returns:
        Mapping node -> 2-D array with one row per (sorted) state of the node and one column
        per combination of (sorted) parent states, the last-listed parent varying fastest:
        a: (b, d), b: (d, e), c: (b, d, e); the root nodes d and e have a single column.
        Parent combinations absent from ``data`` get a uniform column for every node.
    """
    df = data  # type: pd.DataFrame  # only read, never mutated
    df_vals = {col: sorted(df[col].unique()) for col in df.columns}

    # Parents of each node in train_model, in the order used to lay out CPD columns.
    parents = {
        "a": ["b", "d"],
        "b": ["d", "e"],
        "c": ["b", "d", "e"],
        "d": [],
        "e": [],
    }

    cpds = {}
    for node, node_parents in parents.items():
        # itertools.product() with no iterables yields one empty tuple -> a single column.
        parent_combos = list(itertools.product(*(df_vals[p] for p in node_parents)))
        n_states = len(df_vals[node])
        cpds[node] = np.array(
            [
                [
                    _conditional_probability(
                        df, node, state, dict(zip(node_parents, combo)), n_states, pc
                    )
                    for combo in parent_combos
                ]
                for state in df_vals[node]
            ]
        )
    return cpds


@pytest.fixture
def train_data_idx_marginals(train_data_idx_cpds):
    """Marginal distribution of every node, keyed by the integer-encoded states used in train_data_idx."""
    n_states = {"a": 4, "b": 3, "c": 5, "d": 2, "e": 2}
    state_labels = {node: list(range(count)) for node, count in n_states.items()}
    return create_marginals(train_data_idx_cpds, state_labels)


@pytest.fixture
def train_data_discrete_marginals(train_data_discrete_cpds):
    """Marginal distribution of every node, keyed by the (sorted) states used in train_data_discrete."""
    state_labels = {
        "a": list("abcd"),
        "b": list("xyz"),
        "c": list(range(5)),
        "d": [False, True],
        "e": [False, True],
    }
    return create_marginals(train_data_discrete_cpds, state_labels)


def create_marginals(cpds, data_vals):
    """
    Compute the marginal distribution of every node of train_model from its CPDs.

    The CPD layout must match create_cpds: one row per node state, and one column per
    combination of parent states with the last-listed parent varying fastest
    (a: (b, d); b: (d, e); c: (b, d, e); d and e: a single column).

    d and e are both parents of b, so b is NOT independent of d or e. The marginals of a
    and c are therefore computed from the joint P(b, d, e) = P(b | d, e) P(d) P(e),
    rather than from the product of the individual marginals.

    Args:
        cpds: mapping node -> CPD array, as produced by create_cpds.
        data_vals: mapping node -> list of state labels, in the same order as the CPD rows.

    Returns:
        Mapping node -> {state label: marginal probability}.
    """
    p_d = cpds["d"][:, 0]
    p_e = cpds["e"][:, 0]
    n_b, n_d, n_e = len(cpds["b"]), len(p_d), len(p_e)

    # Joint P(b, d, e) with axes (b, d, e); cpd_b columns are (d, e) with e fastest.
    joint_bde = (
        cpds["b"].reshape(n_b, n_d, n_e) * p_d[None, :, None] * p_e[None, None, :]
    )
    p_b = joint_bde.sum(axis=(1, 2))

    # P(a) = sum_{b, d} P(a | b, d) P(b, d); cpd_a columns are (b, d) with d fastest.
    joint_bd = joint_bde.sum(axis=2)
    p_a = (cpds["a"].reshape(-1, n_b, n_d) * joint_bd).sum(axis=(1, 2))

    # P(c) = sum_{b, d, e} P(c | b, d, e) P(b, d, e).
    p_c = (cpds["c"].reshape(-1, n_b, n_d, n_e) * joint_bde).sum(axis=(1, 2, 3))

    probabilities = {"a": p_a, "b": p_b, "c": p_c, "d": p_d, "e": p_e}
    return {
        node: {data_vals[node][k]: v for k, v in enumerate(probs)}
        for node, probs in probabilities.items()
    }


@pytest.fixture
def test_data_c() -> pd.DataFrame:
    """Test data for predicting column c.

    Contains one row for every combination of b in ("x", "y", "z") and (d, e) in
    {False, True} x {False, True}, with e varying fastest, giving 12 rows. Columns a and c hold
    fixed per-row values. Per the original author, these were chosen so that c should be
    perfectly predictable from the training CPDs (train_data_discrete_cpds). This is not verified here.
    """
    a_column = ["a", "b", "c", "d", "d", "c", "b", "a", "c", "a", "d", "b"]
    c_column = [1, 2, 3, 4, 1, 2, 23, 64, 1, 2, 3, 0]
    grid = itertools.product(["x", "y", "z"], [False, True], [False, True])

    rows = [
        [a_value, b_value, c_value, d_value, e_value]
        for a_value, c_value, (b_value, d_value, e_value) in zip(a_column, c_column, grid)
    ]
    return pd.DataFrame(rows, columns=["a", "b", "c", "d", "e"])


@pytest.fixture
def test_data_c_discrete(test_data_c) -> pd.DataFrame:
    """test_data_c with column "c" bucketed like train_data_discrete (bins of width 20; 80 and above -> 4)."""

    def bucket(value):
        if value < 20:
            return 0
        if value < 40:
            return 1
        if value < 60:
            return 2
        if value < 80:
            return 3
        return 4

    discretised = test_data_c.copy(deep=True)  # type: pd.DataFrame
    discretised["c"] = discretised["c"].apply(bucket)
    return discretised


@pytest.fixture
def test_data_c_likelihood(train_data_discrete_cpds) -> pd.DataFrame:
    """Conditional distribution of c given its parents, from train_data_discrete_cpds.

    This is the CPD of c transposed: one row per parent configuration, in the CPD's column
    order, and one column "c_k" per state k of c.
    """
    cpd_c = train_data_discrete_cpds["c"]
    # tolist() builds fresh Python rows, so the frame never shares memory with the CPD array.
    per_configuration = cpd_c.T.tolist()
    return pd.DataFrame(per_configuration, columns=["c_0", "c_1", "c_2", "c_3", "c_4"])


@pytest.fixture
def bn(train_data_idx, train_data_discrete) -> BayesianNetwork:
    """Learn a structure with NOTEARS on the integer-encoded data, then fit node states and CPDs on the discretised data."""
    learned_structure = from_pandas(train_data_idx, w_threshold=0.3)
    network = BayesianNetwork(learned_structure)
    return network.fit_node_states_and_cpds(train_data_discrete)


@pytest.fixture
def empty_cpd() -> pd.DataFrame:
    """Create an all-NaN CPD table for node "b" with parents "d" and "e".

    Rows are the states of b ("x", "y", "z"). Columns form a MultiIndex over every (d, e)
    combination, in the order (False, False), (False, True), (True, False), (True, True).
    good_cpd and bad_cpd fill the table positionally, so this order must stay fixed.
    """
    # Ordered tuples (not sets) so the column order is explicit rather than hash-dependent.
    parents = {"d": (False, True), "e": (False, True)}
    columns = pd.MultiIndex.from_product(list(parents.values()), names=list(parents))
    # np.nan (np.NaN was removed in NumPy 2.0).
    df = pd.DataFrame(np.nan, index=["x", "y", "z"], columns=columns)
    df.index.name = "b"
    return df


@pytest.fixture
def good_cpd(empty_cpd) -> pd.DataFrame:
    """Create a valid CPD table: every entry lies in [0, 1] and every column sums to 1."""
    df = empty_cpd  # filled in place; rows are b's states, columns the (d, e) combinations
    df.loc[:] = [[0.2, 1.0, 0.4, 0.1], [0.7, 0.0, 0.0, 0.1], [0.1, 0.0, 0.6, 0.8]]
    return df


@pytest.fixture
def bad_cpd(empty_cpd) -> pd.DataFrame:
    """Fill the empty CPD in place with values that violate the probability axioms
    (entries above 1 and columns that do not sum to 1)."""
    invalid_probabilities = [
        [0.2, 1.0, 0.4, 0.1],
        [0.7, 2.0, 3.0, 0.1],
        [0.3, 1.0, 0.6, 5.8],
    ]
    empty_cpd.loc[:] = invalid_probabilities
    return empty_cpd


@pytest.fixture
def parentless_cpd() -> pd.DataFrame:
    """A valid single-column CPD for the root node "e": P(e=False)=0.3, P(e=True)=0.7."""
    probabilities = [[0.3], [0.7]]
    states = pd.Index([False, True], name="e")
    return pd.DataFrame(probabilities, index=states)


@pytest.fixture()
def data_dynotears_p1() -> Dict[str, np.ndarray]:
    """
    Training data for testing Dynamic Bayesian Networks. Return a time series with 50 time points, with 5 columns.
    This data was simulated with the following configurations:
    Configurations:
        - data points 50,
        - num. variables: 5,
        - p (lag amount): 1,
        - graph type (intra-slice graph): 'erdos-renyi',
        - graph type (inter-slice graph): 'erdos-renyi',
        - SEM type: 'linear-gauss',
        - weight range, intra-slice graph: (0.5, 2.0),
        - weight range, inter-slice graph: (0.3, 0.5),
        - expected degree, inter-slice graph: 3,
        - noise scale (gaussian noise): 1.0,
        - w decay: 1.1
    Returns:
        dictionary with keys W (intra-weights), A (inter-weights), X and Y (inputs of from_numpy_dynamic).
        W and A are 5 x 5. X is 50 x 5, one row per time point. Y is 50 x 5 and holds the lag-1
        values: row i of Y equals row i - 1 of X (Y[0] presumably being the observation preceding X[0]).
        NOTE(review): presumably W[i, j] / A[i, j] is the weight of the edge from variable i to
        variable j — confirm against the simulator / from_numpy_dynamic convention.
    """
    data = {
        "W": np.array(
            [
                [0.0, -0.55, 0.0, 1.48, 0.0],
                [0.0, 0.0, -0.99, 0.0, 0.0],
                [0.0, 0.0, 0.0, -1.13, 0.0],
                [0.0, 0.0, 0.0, 0.0, 0.0],
                [0.0, -1.91, -0.64, -1.31, 0.0],
            ]
        ),
        "A": np.array(
            [
                [-0.35, -0.32, -0.33, 0.37, 0.0],
                [-0.41, -0.42, -0.36, 0.33, -0.35],
                [0.46, 0.0, 0.44, -0.36, -0.38],
                [0.0, 0.0, -0.45, 0.0, 0.43],
                [0.31, 0.4, 0.0, 0.0, -0.44],
            ]
        ),
        "X": np.array(
            [
                [-8.7, -3.1, -5.1, 2.8, -2.0],
                [0.3, -5.8, 2.9, -12.6, 5.1],
                [4.8, 17.4, -4.2, 17.6, -6.8],
                [-12.5, -20.6, -1.9, -18.5, 8.3],
                [14.4, 14.1, 7.7, 6.8, -3.0],
                [-8.1, 0.1, -8.6, 9.0, -3.3],
                [-3.3, -11.4, 1.3, -18.4, 8.0],
                [9.4, 14.4, 3.7, 11.4, -6.5],
                [-9.7, -7.0, -3.5, -3.2, -0.0],
                [1.7, -2.7, 5.8, -14.8, 4.5],
                [3.9, 17.2, -2.4, 19.2, -9.8],
                [-12.5, -22.2, 1.2, -23.2, 7.8],
                [17.2, 18.2, 9.2, 11.6, -6.4],
                [-12.2, -7.9, -2.8, -2.6, -1.2],
                [5.8, -4.9, 9.0, -12.6, 3.9],
                [6.0, 13.5, 2.3, 15.1, -8.4],
                [-10.1, -12.8, -1.4, -13.4, 4.5],
                [9.7, 10.3, 5.5, 7.1, -3.5],
                [-7.2, -1.4, -4.5, 0.2, -1.3],
                [1.8, -3.7, 1.6, -4.8, 3.2],
                [2.6, 6.7, -2.1, 9.7, -2.7],
                [-4.6, -7.6, -1.7, -4.4, 2.6],
                [4.6, -0.1, 4.5, -4.1, 1.9],
                [3.3, 5.1, -0.7, 11.0, -4.1],
                [-4.3, -12.7, -0.7, -9.5, 5.9],
                [8.0, 6.0, 5.6, 0.8, -1.1],
                [-2.4, 4.9, -3.9, 11.3, -4.7],
                [-4.6, -12.0, -0.3, -13.4, 6.5],
                [8.1, 15.9, -0.8, 16.0, -6.5],
                [-11.5, -14.8, -5.0, -7.5, 4.4],
                [10.0, 3.3, 7.1, -2.8, 1.7],
                [-0.7, 3.8, -1.8, 9.2, -4.0],
                [-1.6, -13.7, 5.3, -13.0, 4.7],
                [11.0, 8.2, 7.7, 6.7, -4.0],
                [-4.0, -2.4, -3.0, 1.8, -1.7],
                [1.1, -4.8, 2.9, -5.9, 3.0],
                [3.4, 7.5, -1.5, 9.0, -2.4],
                [-6.2, -3.4, -4.8, -2.2, 1.0],
                [0.8, -0.3, 1.8, -2.5, 1.0],
                [0.4, 1.7, -1.6, 1.9, -0.5],
                [-2.3, -4.5, 2.9, -7.4, 2.3],
                [3.9, 9.3, 2.4, 4.1, -3.8],
                [-4.8, -0.6, -3.6, 3.0, -2.1],
                [-1.8, -7.5, 2.9, -13.2, 4.7],
                [5.6, 19.0, -3.7, 18.1, -7.8],
                [-13.5, -19.3, -2.7, -17.7, 7.5],
                [15.5, 10.6, 9.9, 5.7, -3.1],
                [-6.8, 3.3, -6.8, 9.6, -4.8],
                [-5.2, -15.3, 3.9, -21.6, 7.9],
                [11.1, 22.0, -0.3, 20.7, -8.4],
            ]
        ),
        "Y": np.array(
            [
                [10.8, 14.6, 5.8, 8.1, -4.5],
                [-8.7, -3.1, -5.1, 2.8, -2.0],
                [0.3, -5.8, 2.9, -12.6, 5.1],
                [4.8, 17.4, -4.2, 17.6, -6.8],
                [-12.5, -20.6, -1.9, -18.5, 8.3],
                [14.4, 14.1, 7.7, 6.8, -3.0],
                [-8.1, 0.1, -8.6, 9.0, -3.3],
                [-3.3, -11.4, 1.3, -18.4, 8.0],
                [9.4, 14.4, 3.7, 11.4, -6.5],
                [-9.7, -7.0, -3.5, -3.2, -0.0],
                [1.7, -2.7, 5.8, -14.8, 4.5],
                [3.9, 17.2, -2.4, 19.2, -9.8],
                [-12.5, -22.2, 1.2, -23.2, 7.8],
                [17.2, 18.2, 9.2, 11.6, -6.4],
                [-12.2, -7.9, -2.8, -2.6, -1.2],
                [5.8, -4.9, 9.0, -12.6, 3.9],
                [6.0, 13.5, 2.3, 15.1, -8.4],
                [-10.1, -12.8, -1.4, -13.4, 4.5],
                [9.7, 10.3, 5.5, 7.1, -3.5],
                [-7.2, -1.4, -4.5, 0.2, -1.3],
                [1.8, -3.7, 1.6, -4.8, 3.2],
                [2.6, 6.7, -2.1, 9.7, -2.7],
                [-4.6, -7.6, -1.7, -4.4, 2.6],
                [4.6, -0.1, 4.5, -4.1, 1.9],
                [3.3, 5.1, -0.7, 11.0, -4.1],
                [-4.3, -12.7, -0.7, -9.5, 5.9],
                [8.0, 6.0, 5.6, 0.8, -1.1],
                [-2.4, 4.9, -3.9, 11.3, -4.7],
                [-4.6, -12.0, -0.3, -13.4, 6.5],
                [8.1, 15.9, -0.8, 16.0, -6.5],
                [-11.5, -14.8, -5.0, -7.5, 4.4],
                [10.0, 3.3, 7.1, -2.8, 1.7],
                [-0.7, 3.8, -1.8, 9.2, -4.0],
                [-1.6, -13.7, 5.3, -13.0, 4.7],
                [11.0, 8.2, 7.7, 6.7, -4.0],
                [-4.0, -2.4, -3.0, 1.8, -1.7],
                [1.1, -4.8, 2.9, -5.9, 3.0],
                [3.4, 7.5, -1.5, 9.0, -2.4],
                [-6.2, -3.4, -4.8, -2.2, 1.0],
                [0.8, -0.3, 1.8, -2.5, 1.0],
                [0.4, 1.7, -1.6, 1.9, -0.5],
                [-2.3, -4.5, 2.9, -7.4, 2.3],
                [3.9, 9.3, 2.4, 4.1, -3.8],
                [-4.8, -0.6, -3.6, 3.0, -2.1],
                [-1.8, -7.5, 2.9, -13.2, 4.7],
                [5.6, 19.0, -3.7, 18.1, -7.8],
                [-13.5, -19.3, -2.7, -17.7, 7.5],
                [15.5, 10.6, 9.9, 5.7, -3.1],
                [-6.8, 3.3, -6.8, 9.6, -4.8],
                [-5.2, -15.3, 3.9, -21.6, 7.9],
            ]
        ),
    }
    return data


@pytest.fixture()
def data_dynotears_p2() -> Dict[str, np.ndarray]:
    """
    Training data for testing Dynamic Bayesian Networks. Return a time series with 50 time points, with 5 columns.
    This data was simulated with the following configurations:
    Configurations:
        - data points 50,
        - num. variables: 5,
        - p (lag amount): 2,
        - graph type (intra-slice graph): 'erdos-renyi',
        - graph type (inter-slice graph): 'erdos-renyi',
        - SEM type: 'linear-gauss',
        - weight range, intra-slice graph: (0.5, 2.0),
        - weight range, inter-slice graph: (0.3, 0.5),
        - expected degree, inter-slice graph: 3,
        - noise scale (gaussian noise): 1.0,
        - w decay: 1.1
    Returns:
        dictionary with keys W (intra-weights), A (inter-weights), X and Y (inputs of from_numpy_dynamic).
        W is 5 x 5 and A is 10 x 5. X is 50 x 5, one row per time point. Y is 50 x 10:
        columns 0-4 hold the lag-1 values (row i equals row i - 1 of X), and columns 5-9 hold the
        lag-2 values (row i equals columns 0-4 of row i - 1 of Y).
        NOTE(review): presumably rows 0-4 of A are the lag-1 weights and rows 5-9 the lag-2
        weights, and W[i, j] / A[i, j] is the weight of the edge from variable i to variable j —
        confirm against the simulator / from_numpy_dynamic convention.
    """
    data = {
        "W": np.array(
            [
                [0.0, 0.0, 0.0, 0.0, -1.08],
                [1.16, 0.0, 0.0, -0.81, 0.89],
                [-1.83, 0.58, 0.0, 0.61, -1.31],
                [1.03, 0.0, 0.0, 0.0, -0.97],
                [0.0, 0.0, 0.0, 0.0, 0.0],
            ]
        ),
        "A": np.array(
            [
                [0.31, 0.0, 0.0, 0.32, 0.0],
                [0.0, 0.36, 0.0, 0.0, 0.5],
                [-0.43, 0.0, 0.0, -0.49, -0.39],
                [0.39, 0.0, 0.39, 0.0, 0.38],
                [0.0, 0.37, 0.0, 0.43, 0.0],
                [0.0, -0.37, 0.34, 0.0, 0.0],
                [-0.3, 0.0, -0.42, 0.0, 0.45],
                [0.0, -0.34, 0.0, 0.0, 0.0],
                [0.0, 0.31, 0.0, 0.29, 0.36],
                [-0.43, 0.37, 0.0, 0.0, -0.34],
            ]
        ),
        "X": np.array(
            [
                [3.1, 0.9, 1.6, 2.9, -6.5],
                [-5.6, -4.1, 3.9, 4.7, -3.4],
                [-6.3, -3.5, 2.2, 1.4, 0.4],
                [-0.3, 0.2, -0.5, -1.5, 2.0],
                [2.4, 1.6, -0.9, -0.7, -0.9],
                [0.6, -0.1, 0.2, 1.5, -2.7],
                [-2.4, -1.6, 2.2, 1.3, -1.6],
                [-4.6, 0.5, 1.4, -2.6, 6.9],
                [-1.3, -0.1, -1.3, -1.7, 3.7],
                [0.5, 4.7, -2.0, -4.1, 6.8],
                [8.8, 4.9, -2.7, -1.1, -0.3],
                [3.8, 3.1, -0.4, 0.1, 1.5],
                [2.0, -0.9, 0.1, 3.9, -1.6],
                [-3.8, -0.9, 2.7, 1.3, 2.4],
                [-4.4, 3.2, 2.1, -3.3, 9.9],
                [0.2, 4.8, -1.6, -3.7, 9.7],
                [5.1, 2.8, -6.4, -1.5, 3.7],
                [8.1, 4.3, -2.4, -0.1, -1.6],
                [2.1, 0.9, 1.5, 2.9, -4.1],
                [-4.1, -3.9, 0.2, 1.1, 4.1],
                [0.3, -3.0, 0.2, 3.0, -4.7],
                [-1.2, 2.8, 2.7, -0.9, -0.9],
                [-0.8, 1.8, 1.7, -0.7, 2.4],
                [-0.2, -2.2, -1.8, 2.6, -0.3],
                [-0.9, 2.1, 0.8, -2.3, 5.9],
                [1.5, 3.6, 0.4, -0.1, -0.1],
                [-3.0, 1.3, -1.6, -3.2, 9.5],
                [6.8, -0.1, -3.5, 2.5, -3.2],
                [3.5, 1.8, -1.7, -2.1, 0.5],
                [5.1, 1.0, 1.6, 3.6, -6.5],
                [-2.8, -3.0, 1.4, 1.8, -3.5],
                [-5.9, -4.8, 3.2, 4.0, -3.6],
                [-6.7, -2.2, 1.3, -2.7, 4.8],
                [0.4, 0.0, -1.8, -0.9, 0.4],
                [3.2, 1.7, -2.9, -3.6, 3.1],
                [5.7, 0.9, -2.9, -0.3, -1.4],
                [3.4, -2.0, -0.5, 3.5, -9.1],
                [-4.3, -4.4, 2.9, 2.6, -5.2],
                [-9.7, -6.4, 3.4, 1.9, -1.4],
                [-5.2, -1.4, 1.6, -2.3, 1.0],
                [-1.1, 1.1, -1.6, -3.5, 4.1],
                [0.9, 2.0, -1.6, -2.6, 3.3],
                [4.3, 4.8, -0.7, -0.7, 0.3],
                [4.9, -0.4, -1.3, 1.8, -3.9],
                [1.3, -2.9, -0.8, 2.0, -1.3],
                [-0.1, -0.8, 2.9, 3.1, -6.7],
                [-5.8, 0.5, 4.0, -1.2, 2.6],
                [-1.8, -2.4, -2.0, -2.4, 3.5],
                [3.3, 0.7, -3.0, -0.6, 0.0],
                [4.5, -0.1, 0.4, 2.9, -10.8],
            ]
        ),
        "Y": np.array(
            [
                [6.1, 0.5, -2.0, 0.9, -3.0, 2.9, 1.5, -0.9, -0.2, 0.8],
                [3.1, 0.9, 1.6, 2.9, -6.5, 6.1, 0.5, -2.0, 0.9, -3.0],
                [-5.6, -4.1, 3.9, 4.7, -3.4, 3.1, 0.9, 1.6, 2.9, -6.5],
                [-6.3, -3.5, 2.2, 1.4, 0.4, -5.6, -4.1, 3.9, 4.7, -3.4],
                [-0.3, 0.2, -0.5, -1.5, 2.0, -6.3, -3.5, 2.2, 1.4, 0.4],
                [2.4, 1.6, -0.9, -0.7, -0.9, -0.3, 0.2, -0.5, -1.5, 2.0],
                [0.6, -0.1, 0.2, 1.5, -2.7, 2.4, 1.6, -0.9, -0.7, -0.9],
                [-2.4, -1.6, 2.2, 1.3, -1.6, 0.6, -0.1, 0.2, 1.5, -2.7],
                [-4.6, 0.5, 1.4, -2.6, 6.9, -2.4, -1.6, 2.2, 1.3, -1.6],
                [-1.3, -0.1, -1.3, -1.7, 3.7, -4.6, 0.5, 1.4, -2.6, 6.9],
                [0.5, 4.7, -2.0, -4.1, 6.8, -1.3, -0.1, -1.3, -1.7, 3.7],
                [8.8, 4.9, -2.7, -1.1, -0.3, 0.5, 4.7, -2.0, -4.1, 6.8],
                [3.8, 3.1, -0.4, 0.1, 1.5, 8.8, 4.9, -2.7, -1.1, -0.3],
                [2.0, -0.9, 0.1, 3.9, -1.6, 3.8, 3.1, -0.4, 0.1, 1.5],
                [-3.8, -0.9, 2.7, 1.3, 2.4, 2.0, -0.9, 0.1, 3.9, -1.6],
                [-4.4, 3.2, 2.1, -3.3, 9.9, -3.8, -0.9, 2.7, 1.3, 2.4],
                [0.2, 4.8, -1.6, -3.7, 9.7, -4.4, 3.2, 2.1, -3.3, 9.9],
                [5.1, 2.8, -6.4, -1.5, 3.7, 0.2, 4.8, -1.6, -3.7, 9.7],
                [8.1, 4.3, -2.4, -0.1, -1.6, 5.1, 2.8, -6.4, -1.5, 3.7],
                [2.1, 0.9, 1.5, 2.9, -4.1, 8.1, 4.3, -2.4, -0.1, -1.6],
                [-4.1, -3.9, 0.2, 1.1, 4.1, 2.1, 0.9, 1.5, 2.9, -4.1],
                [0.3, -3.0, 0.2, 3.0, -4.7, -4.1, -3.9, 0.2, 1.1, 4.1],
                [-1.2, 2.8, 2.7, -0.9, -0.9, 0.3, -3.0, 0.2, 3.0, -4.7],
                [-0.8, 1.8, 1.7, -0.7, 2.4, -1.2, 2.8, 2.7, -0.9, -0.9],
                [-0.2, -2.2, -1.8, 2.6, -0.3, -0.8, 1.8, 1.7, -0.7, 2.4],
                [-0.9, 2.1, 0.8, -2.3, 5.9, -0.2, -2.2, -1.8, 2.6, -0.3],
                [1.5, 3.6, 0.4, -0.1, -0.1, -0.9, 2.1, 0.8, -2.3, 5.9],
                [-3.0, 1.3, -1.6, -3.2, 9.5, 1.5, 3.6, 0.4, -0.1, -0.1],
                [6.8, -0.1, -3.5, 2.5, -3.2, -3.0, 1.3, -1.6, -3.2, 9.5],
                [3.5, 1.8, -1.7, -2.1, 0.5, 6.8, -0.1, -3.5, 2.5, -3.2],
                [5.1, 1.0, 1.6, 3.6, -6.5, 3.5, 1.8, -1.7, -2.1, 0.5],
                [-2.8, -3.0, 1.4, 1.8, -3.5, 5.1, 1.0, 1.6, 3.6, -6.5],
                [-5.9, -4.8, 3.2, 4.0, -3.6, -2.8, -3.0, 1.4, 1.8, -3.5],
                [-6.7, -2.2, 1.3, -2.7, 4.8, -5.9, -4.8, 3.2, 4.0, -3.6],
                [0.4, 0.0, -1.8, -0.9, 0.4, -6.7, -2.2, 1.3, -2.7, 4.8],
                [3.2, 1.7, -2.9, -3.6, 3.1, 0.4, 0.0, -1.8, -0.9, 0.4],
                [5.7, 0.9, -2.9, -0.3, -1.4, 3.2, 1.7, -2.9, -3.6, 3.1],
                [3.4, -2.0, -0.5, 3.5, -9.1, 5.7, 0.9, -2.9, -0.3, -1.4],
                [-4.3, -4.4, 2.9, 2.6, -5.2, 3.4, -2.0, -0.5, 3.5, -9.1],
                [-9.7, -6.4, 3.4, 1.9, -1.4, -4.3, -4.4, 2.9, 2.6, -5.2],
                [-5.2, -1.4, 1.6, -2.3, 1.0, -9.7, -6.4, 3.4, 1.9, -1.4],
                [-1.1, 1.1, -1.6, -3.5, 4.1, -5.2, -1.4, 1.6, -2.3, 1.0],
                [0.9, 2.0, -1.6, -2.6, 3.3, -1.1, 1.1, -1.6, -3.5, 4.1],
                [4.3, 4.8, -0.7, -0.7, 0.3, 0.9, 2.0, -1.6, -2.6, 3.3],
                [4.9, -0.4, -1.3, 1.8, -3.9, 4.3, 4.8, -0.7, -0.7, 0.3],
                [1.3, -2.9, -0.8, 2.0, -1.3, 4.9, -0.4, -1.3, 1.8, -3.9],
                [-0.1, -0.8, 2.9, 3.1, -6.7, 1.3, -2.9, -0.8, 2.0, -1.3],
                [-5.8, 0.5, 4.0, -1.2, 2.6, -0.1, -0.8, 2.9, 3.1, -6.7],
                [-1.8, -2.4, -2.0, -2.4, 3.5, -5.8, 0.5, 4.0, -1.2, 2.6],
                [3.3, 0.7, -3.0, -0.6, 0.0, -1.8, -2.4, -2.0, -2.4, 3.5],
            ]
        ),
    }
    return data


@pytest.fixture()
def data_dynotears_p3() -> Dict[str, np.ndarray]:
    """
    Training data for testing Dynamic Bayesian Networks. Return a time series with 50 time points, with 5 columns.
    This data was simulated with the following configurations.
    Configurations:
        - data points 50,
        - num. variables: 5,
        - p (lag amount): 3,
        - graph type (intra-slice graph): 'erdos-renyi',
        - graph type (inter-slice graph): 'erdos-renyi',
        - SEM type: 'linear-gauss',
        - weight range, intra-slice graph: (0.5, 2.0),
        - weight range, inter-slice graph: (0.3, 0.5),
        - expected degree, inter-slice graph: 3,
        - noise scale (gaussian noise): 1.0,
        - w decay: 1.1
    Returns:
        dictionary with keys W (intra-weights, shape (5, 5)), A (inter-weights, shape (15, 5)),
        X (shape (50, 5)) and Y (shape (50, 15)); X and Y are the inputs of from_numpy_dynamic.
        Row t of Y is the concatenation [X[t-1], X[t-2], X[t-3]]; for the first rows, the time
        points before X[0] are taken from the three pre-sample rows defined below.
    """
    # The three simulated time points immediately preceding X[0], oldest first.
    # They supply the lag history needed by the first rows of Y.
    pre_sample = np.array(
        [
            [22.4, 6.9, -22.7, -6.3, -40.2],
            [11.0, -26.3, 66.7, 21.5, -82.3],
            [-4.9, -31.2, 50.6, -5.4, 42.7],
        ]
    )
    X = np.array(
        [
            [-21.5, 26.1, -56.8, -15.4, 82.4],
            [14.4, 20.3, -51.3, -18.8, -4.6],
            [16.5, -17.0, 48.2, 18.7, -100.2],
            [5.1, -38.8, 70.5, 5.1, 8.6],
            [-19.4, 20.5, -47.8, -4.3, 86.9],
            [9.8, 28.7, -69.5, -18.3, 8.2],
            [9.5, -7.1, 42.4, 13.3, -107.0],
            [2.8, -42.5, 83.7, 1.5, 3.1],
            [-25.7, 18.6, -42.4, -8.5, 112.6],
            [11.5, 34.3, -94.2, -20.5, 33.3],
            [19.0, -9.7, 39.5, 18.2, -137.5],
            [7.9, -60.4, 122.3, 8.1, -33.3],
            [-32.5, 6.0, -19.4, -11.0, 130.2],
            [3.5, 48.0, -127.1, -33.7, 81.8],
            [28.4, 7.6, 1.0, 16.3, -146.2],
            [20.0, -71.8, 153.1, 17.9, -93.3],
            [-34.3, -13.0, 20.5, -4.4, 130.2],
            [-9.4, 58.4, -150.8, -41.9, 141.6],
            [26.7, 32.9, -49.8, 7.9, -134.1],
            [34.3, -74.1, 161.6, 26.1, -152.1],
            [-34.2, -35.1, 75.5, 8.0, 105.1],
            [-21.9, 54.5, -151.3, -49.8, 199.9],
            [19.2, 54.9, -98.1, -8.1, -95.3],
            [47.4, -72.0, 158.9, 22.4, -192.6],
            [-25.7, -60.3, 129.5, 21.9, 54.6],
            [-26.9, 42.2, -136.6, -48.4, 243.1],
            [12.6, 79.6, -153.1, -11.6, -57.4],
            [54.1, -66.7, 153.3, 17.7, -241.5],
            [-26.1, -79.5, 185.0, 25.9, 19.5],
            [-34.8, 36.3, -133.8, -54.2, 311.0],
            [10.6, 105.3, -219.7, -19.1, -15.3],
            [73.9, -58.5, 138.7, 25.1, -307.4],
            [-17.7, -107.1, 257.4, 47.9, -58.4],
            [-51.3, 17.0, -93.4, -59.1, 354.7],
            [-13.4, 140.4, -287.4, -43.0, 83.7],
            [85.8, -28.4, 71.1, 11.6, -310.8],
            [4.5, -130.3, 310.9, 68.7, -171.8],
            [-51.1, -25.3, -12.3, -46.8, 361.3],
            [-34.2, 158.2, -341.9, -48.7, 191.1],
            [92.6, 7.7, -9.7, -2.9, -314.2],
            [20.2, -143.1, 363.2, 83.7, -294.6],
            [-54.6, -69.6, 82.3, -36.4, 356.7],
            [-60.4, 176.2, -390.3, -61.8, 319.4],
            [97.2, 53.8, -116.7, -18.9, -289.0],
            [46.0, -152.8, 400.1, 103.7, -440.8],
            [-52.1, -124.2, 206.7, -16.7, 307.3],
            [-93.8, 183.7, -415.9, -74.5, 465.6],
            [92.4, 118.4, -265.0, -47.7, -202.0],
            [76.5, -140.2, 396.2, 116.8, -592.8],
            [-36.6, -195.8, 366.6, 8.3, 202.5],
        ]
    )

    # Y is derived from X instead of being hard-coded, so its lag blocks are
    # guaranteed to match X exactly. series[k] == X[k - n_lags].
    series = np.vstack([pre_sample, X])
    n_lags = pre_sample.shape[0]
    n_steps = X.shape[0]
    Y = np.hstack(
        [
            # Block for lag `lag`: row t holds X[t - lag] == series[t - lag + n_lags].
            series[n_lags - lag : n_lags - lag + n_steps]
            for lag in range(1, n_lags + 1)
        ]
    )

    data = {
        "W": np.array(
            [
                [0.0, 0.0, -1.18, 0.0, -0.92],
                [0.0, 0.0, -1.71, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, -1.15],
                [0.0, 0.8, 0.0, 0.0, -1.65],
                [0.0, 0.0, 0.0, 0.0, 0.0],
            ]
        ),
        "A": np.array(
            [
                [-0.42, -0.4, 0.35, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 0.0],
                [-0.38, 0.0, -0.48, -0.48, 0.44],
                [0.0, -0.36, -0.31, 0.35, 0.0],
                [0.0, 0.0, -0.35, 0.0, -0.45],
                [0.27, 0.0, 0.0, -0.28, -0.43],
                [-0.44, 0.35, 0.0, 0.0, -0.44],
                [0.0, 0.0, 0.0, 0.0, 0.0],
                [-0.42, -0.39, -0.27, 0.3, -0.29],
                [0.0, -0.43, 0.0, -0.41, 0.4],
                [0.0, 0.32, 0.38, -0.37, 0.37],
                [0.0, 0.37, 0.0, 0.0, 0.41],
                [0.0, 0.34, 0.0, 0.27, -0.28],
                [-0.31, 0.0, 0.0, -0.34, 0.0],
                [0.27, -0.38, 0.0, 0.37, 0.36],
            ]
        ),
        "X": X,
        "Y": Y,
    }
    return data


@pytest.fixture
def adjacency_mat_num_stability() -> np.ndarray:
    """
    Weighted 5 x 5 adjacency matrix used to train structure learning algorithms.
    """
    weight_rows = [
        [0.0, 0.0, 0.0, 0.0, 0.0],
        [-0.6, 0.0, 0.0, 0.0, 1.27],
        [0.9, 0.0, 0.0, 0.0, -0.98],
        [0.0, -0.89, 1.37, 0.0, 0.0],
        [1.74, 0.0, 0.0, 0.0, 0.0],
    ]
    return np.array(weight_rows)


@pytest.fixture
def iris_test_data() -> pd.DataFrame:
    """
    Iris dataset as a DataFrame, used to test the sklearn wrappers.

    The four feature columns keep their sklearn names, the class label is stored
    in a "type" column, and "sepal length (cm)" is replaced by its quantile
    discretisation into 3 buckets.
    """
    bunch = load_iris()
    frame = pd.DataFrame(bunch["data"], columns=bunch["feature_names"])
    frame["type"] = bunch["target"]

    sepal_length = "sepal length (cm)"
    quantile_discretiser = Discretiser(method="quantile", num_buckets=3)
    frame[sepal_length] = quantile_discretiser.fit_transform(
        frame[sepal_length].values
    )
    return frame


@pytest.fixture
def iris_edge_list():
    """
    (parent, child) edges used to build a Bayesian network over the iris data.
    """
    sepal_length = "sepal length (cm)"
    sepal_width = "sepal width (cm)"
    petal_length = "petal length (cm)"
    petal_width = "petal width (cm)"
    return [
        (sepal_width, sepal_length),
        (petal_length, sepal_length),
        (petal_length, sepal_width),
        (petal_width, petal_length),
        ("type", sepal_width),
        ("type", petal_width),
    ]


@pytest.fixture
def chain_network() -> BayesianNetwork:
    """
    This Bayesian Model structure to test do interventions that split graph
    into subgraphs.

    a → b → c → d → e

    Node states and CPDs are fitted on 50 rows of binary data, where each value
    is 1 with probability 0.3. The data comes from a locally seeded generator,
    so the fixture is reproducible.
    """
    n_samples = 50
    node_names = list("abcde")
    # Use a local, seeded generator instead of the global np.random state. The
    # fitted CPDs are then identical on every run, and they neither depend on
    # nor perturb the global random state used by other tests.
    rng = np.random.RandomState(seed=42)
    # randint(10) draws uniformly from 0..9, so "> 6" yields 1 with probability 0.3.
    binary_matrix = (
        rng.randint(10, size=(n_samples, len(node_names))) > 6
    ).astype(int)
    df = pd.DataFrame(data=binary_matrix, columns=node_names)

    model = StructureModel()
    model.add_edges_from(
        [
            ("a", "b"),
            ("b", "c"),
            ("c", "d"),
            ("d", "e"),
        ]
    )
    chain_bn = BayesianNetwork(model)
    chain_bn = chain_bn.fit_node_states(df)
    chain_bn = chain_bn.fit_cpds(df, method="BayesianEstimator", bayes_prior="K2")
    return chain_bn
